application: targeted data collection

knowing what we know, where and when should we plan to next collect data?

planning the next test

survival analysis

Code
library(cmdstanr)

survival_model <- cmdstan_model(stan_file = "survival.stan")
survival_model$format()
data {
  int<lower=0> n_meas; // number of observations
  vector<lower=0>[n_meas] obs_time; // time of observation (right-censoring time)
  vector<lower=0>[n_meas] fail_lb; // lower bound of failure time
  vector<lower=0>[n_meas] fail_ub; // upper bound of failure time
  
  array[n_meas] int<lower=0, upper=1> fail_status; // 1 if a failure has occurred (interval-censored), 0 if right-censored
  
  int<lower=0> n_pred; // number of predictions
  vector<lower=0>[n_pred] pred_time; // time of prediction
}
parameters {
  real<lower=0> scale; // scale parameter
  real<lower=0> shape; // shape parameter
}
model {
  //priors (parameter constraints <lower=0> truncate these at zero)
  scale ~ normal(8, 3);
  shape ~ normal(6, 3);
  
  //likelihood: each unit contributes a censored-data term
  for (n in 1 : n_meas) {
    if (fail_status[n] == 0) {
      // right-censored: P(T > obs_time) = 1 - F(obs_time)
      target += log1m(loglogistic_cdf(obs_time[n] | scale, shape));
    } else {
      // interval-censored: P(fail_lb < T <= fail_ub) = F(fail_ub) - F(fail_lb)
      target += log(loglogistic_cdf(fail_ub[n] | scale, shape)
                    - loglogistic_cdf(fail_lb[n] | scale, shape));
    }
  }
}
generated quantities {
  vector[n_meas] log_lik; // pointwise log-likelihood (for loo / model comparison)
  vector[n_pred] p_fail_pred; // posterior predictive P(failure by pred_time)
  
  for (n in 1 : n_meas) {
    // NOTE: fail_status == 0 is the right-censored case, matching the model
    // block; the previous condition (== 1) had the two branches swapped,
    // producing incorrect pointwise log-likelihoods.
    if (fail_status[n] == 0) {
      // right-censored: P(T > obs_time)
      log_lik[n] = log1m(loglogistic_cdf(obs_time[n] | scale, shape));
    } else {
      // interval-censored: P(fail_lb < T <= fail_ub)
      log_lik[n] = log(loglogistic_cdf(fail_ub[n] | scale, shape)
                       - loglogistic_cdf(fail_lb[n] | scale, shape));
    }
  }
  
  for (n in 1 : n_pred) {
    p_fail_pred[n] = loglogistic_cdf(pred_time[n] | scale, shape);
  }
}
Code
import cmdstanpy

survival_model = cmdstanpy.CmdStanModel(stan_file = "survival.stan")
INFO:cmdstanpy:found newer exe file, not recompiling
Code
stan_code = survival_model.code()

from pygments import highlight
from pygments.lexers import StanLexer
from pygments.formatters import NullFormatter

formatted_stan_code = highlight(stan_code, StanLexer(), NullFormatter())

print(formatted_stan_code)
data {
  int <lower = 0> n_meas;                   // number of observations
  vector <lower = 0> [n_meas] obs_time;     // time of observation (right-censoring time)
  vector <lower = 0> [n_meas] fail_lb;      // lower bound of failure time
  vector <lower = 0> [n_meas] fail_ub;      // upper bound of failure time

  array [n_meas] int<lower = 0, upper = 1> fail_status; // if a failure has occurred, we have interval-censored data

  int <lower = 0> n_pred;                   // number of predictions
  vector <lower = 0> [n_pred] pred_time;    // time of prediction
}

parameters {
  real <lower = 0> scale; // scale parameter
  real <lower = 0> shape; // shape parameter
}

model{
    //priors
    scale ~ normal(8, 3);
    shape ~ normal(6, 3);

    //likelihood
    for(n in 1:n_meas){
        if(fail_status[n] == 0){
            target += log1m(loglogistic_cdf(obs_time[n] | scale, shape));
        } else {
            target += log(
                          loglogistic_cdf(fail_ub[n] | scale, shape) - 
                          loglogistic_cdf(fail_lb[n] | scale, shape)
                        );
        }
    }
}

generated quantities {
  vector [n_meas] log_lik;      // pointwise log-likelihood (for loo / model comparison)
  vector [n_pred] p_fail_pred;  // posterior predictive P(failure by pred_time)

  for(n in 1:n_meas){
    // NOTE: fail_status == 0 is the right-censored case, matching the model
    // block; the previous condition (== 1) had the two branches swapped,
    // producing incorrect pointwise log-likelihoods.
    if(fail_status[n] == 0){
      // right-censored: P(T > obs_time)
      log_lik[n] = log1m(loglogistic_cdf(obs_time[n] | scale, shape));
    } else {
      // interval-censored: P(fail_lb < T <= fail_ub)
      log_lik[n] = log(
                        loglogistic_cdf(fail_ub[n] | scale, shape) - 
                        loglogistic_cdf(fail_lb[n] | scale, shape)
                      );
    }
  }

  for(n in 1:n_pred){
    p_fail_pred[n] = loglogistic_cdf(pred_time[n] | scale, shape);
  }
  
}
Code
using Turing
using LogExpFunctions: log1mexp

include("../../data/LogLogisticDistribution.jl")
LogLogisticDistribution (generic function with 1 method)
Code

@model function loglogistic_survival(
    obs_time::Vector{Float64},     # time of observation
    fail_lb::Vector{Float64},      # lower bound of failure time
    fail_ub::Vector{Float64},      # upper bound of failure time
    fail_status::Vector{Int}       # 0 if right-censored, 1 if interval-censored
)
    # Priors: weakly-informative normals, truncated to the positive reals.
    scale ~ truncated(Normal(8, 3); lower = 0)
    shape ~ truncated(Normal(6, 3); lower = 0)

    # Log-logistic failure-time distribution for the current parameter draw.
    dist = LogLogisticDistribution(scale, shape)

    # Likelihood: each unit contributes a censored-data term.
    for idx in eachindex(obs_time)
        if fail_status[idx] == 0
            # Right censored: unit survived past its observation time, P(T > obs_time).
            Turing.@addlogprob! log(survival(dist, obs_time[idx]))
        else
            # Interval censored: failure occurred somewhere in (fail_lb, fail_ub).
            Turing.@addlogprob! log(cdf(dist, fail_ub[idx]) - cdf(dist, fail_lb[idx]))
        end
    end
end
loglogistic_survival (generic function with 2 methods)

survival analysis

Code
library(tidyverse)

failure_data <- read_csv("../../data/failures.csv")

model_data <- list(
  n_meas = nrow(failure_data),
  obs_time = rep(12, nrow(failure_data)),
  fail_lb = failure_data$fail_lb,
  fail_ub = failure_data$fail_ub,
  fail_status = is.finite(failure_data$fail_ub) |> as.integer(),
  n_pred = 100,
  pred_time = seq(from = 0, to = 20, length.out = 100)
)

survival_fit <- survival_model$sample(
  data = model_data,
  chains = 4,
  parallel_chains = parallel::detectCores(),
  seed = 231123,
  iter_warmup = 2000,
  iter_sampling = 2000
)
Running MCMC with 4 chains, at most 16 in parallel...

Chain 1 Iteration:    1 / 4000 [  0%]  (Warmup) 
Chain 1 Iteration:  100 / 4000 [  2%]  (Warmup) 
Chain 1 Iteration:  200 / 4000 [  5%]  (Warmup) 
Chain 1 Iteration:  300 / 4000 [  7%]  (Warmup) 
Chain 1 Iteration:  400 / 4000 [ 10%]  (Warmup) 
Chain 1 Iteration:  500 / 4000 [ 12%]  (Warmup) 
Chain 1 Iteration:  600 / 4000 [ 15%]  (Warmup) 
Chain 1 Iteration:  700 / 4000 [ 17%]  (Warmup) 
Chain 1 Iteration:  800 / 4000 [ 20%]  (Warmup) 
Chain 1 Iteration:  900 / 4000 [ 22%]  (Warmup) 
Chain 1 Iteration: 1000 / 4000 [ 25%]  (Warmup) 
Chain 1 Iteration: 1100 / 4000 [ 27%]  (Warmup) 
Chain 1 Iteration: 1200 / 4000 [ 30%]  (Warmup) 
Chain 1 Iteration: 1300 / 4000 [ 32%]  (Warmup) 
Chain 1 Iteration: 1400 / 4000 [ 35%]  (Warmup) 
Chain 1 Iteration: 1500 / 4000 [ 37%]  (Warmup) 
Chain 1 Iteration: 1600 / 4000 [ 40%]  (Warmup) 
Chain 1 Iteration: 1700 / 4000 [ 42%]  (Warmup) 
Chain 1 Iteration: 1800 / 4000 [ 45%]  (Warmup) 
Chain 1 Iteration: 1900 / 4000 [ 47%]  (Warmup) 
Chain 1 Iteration: 2000 / 4000 [ 50%]  (Warmup) 
Chain 1 Iteration: 2001 / 4000 [ 50%]  (Sampling) 
Chain 1 Iteration: 2100 / 4000 [ 52%]  (Sampling) 
Chain 1 Iteration: 2200 / 4000 [ 55%]  (Sampling) 
Chain 1 Iteration: 2300 / 4000 [ 57%]  (Sampling) 
Chain 1 Iteration: 2400 / 4000 [ 60%]  (Sampling) 
Chain 1 Iteration: 2500 / 4000 [ 62%]  (Sampling) 
Chain 1 Iteration: 2600 / 4000 [ 65%]  (Sampling) 
Chain 2 Iteration:    1 / 4000 [  0%]  (Warmup) 
Chain 2 Iteration:  100 / 4000 [  2%]  (Warmup) 
Chain 2 Iteration:  200 / 4000 [  5%]  (Warmup) 
Chain 2 Iteration:  300 / 4000 [  7%]  (Warmup) 
Chain 2 Iteration:  400 / 4000 [ 10%]  (Warmup) 
Chain 2 Iteration:  500 / 4000 [ 12%]  (Warmup) 
Chain 2 Iteration:  600 / 4000 [ 15%]  (Warmup) 
Chain 2 Iteration:  700 / 4000 [ 17%]  (Warmup) 
Chain 2 Iteration:  800 / 4000 [ 20%]  (Warmup) 
Chain 2 Iteration:  900 / 4000 [ 22%]  (Warmup) 
Chain 2 Iteration: 1000 / 4000 [ 25%]  (Warmup) 
Chain 2 Iteration: 1100 / 4000 [ 27%]  (Warmup) 
Chain 2 Iteration: 1200 / 4000 [ 30%]  (Warmup) 
Chain 2 Iteration: 1300 / 4000 [ 32%]  (Warmup) 
Chain 2 Iteration: 1400 / 4000 [ 35%]  (Warmup) 
Chain 2 Iteration: 1500 / 4000 [ 37%]  (Warmup) 
Chain 2 Iteration: 1600 / 4000 [ 40%]  (Warmup) 
Chain 2 Iteration: 1700 / 4000 [ 42%]  (Warmup) 
Chain 2 Iteration: 1800 / 4000 [ 45%]  (Warmup) 
Chain 2 Iteration: 1900 / 4000 [ 47%]  (Warmup) 
Chain 2 Iteration: 2000 / 4000 [ 50%]  (Warmup) 
Chain 2 Iteration: 2001 / 4000 [ 50%]  (Sampling) 
Chain 2 Iteration: 2100 / 4000 [ 52%]  (Sampling) 
Chain 2 Iteration: 2200 / 4000 [ 55%]  (Sampling) 
Chain 2 Iteration: 2300 / 4000 [ 57%]  (Sampling) 
Chain 2 Iteration: 2400 / 4000 [ 60%]  (Sampling) 
Chain 2 Iteration: 2500 / 4000 [ 62%]  (Sampling) 
Chain 2 Iteration: 2600 / 4000 [ 65%]  (Sampling) 
Chain 2 Iteration: 2700 / 4000 [ 67%]  (Sampling) 
Chain 2 Iteration: 2800 / 4000 [ 70%]  (Sampling) 
Chain 2 Iteration: 2900 / 4000 [ 72%]  (Sampling) 
Chain 2 Iteration: 3000 / 4000 [ 75%]  (Sampling) 
Chain 2 Iteration: 3100 / 4000 [ 77%]  (Sampling) 
Chain 3 Iteration:    1 / 4000 [  0%]  (Warmup) 
Chain 3 Iteration:  100 / 4000 [  2%]  (Warmup) 
Chain 3 Iteration:  200 / 4000 [  5%]  (Warmup) 
Chain 3 Iteration:  300 / 4000 [  7%]  (Warmup) 
Chain 3 Iteration:  400 / 4000 [ 10%]  (Warmup) 
Chain 3 Iteration:  500 / 4000 [ 12%]  (Warmup) 
Chain 3 Iteration:  600 / 4000 [ 15%]  (Warmup) 
Chain 3 Iteration:  700 / 4000 [ 17%]  (Warmup) 
Chain 3 Iteration:  800 / 4000 [ 20%]  (Warmup) 
Chain 3 Iteration:  900 / 4000 [ 22%]  (Warmup) 
Chain 3 Iteration: 1000 / 4000 [ 25%]  (Warmup) 
Chain 3 Iteration: 1100 / 4000 [ 27%]  (Warmup) 
Chain 3 Iteration: 1200 / 4000 [ 30%]  (Warmup) 
Chain 3 Iteration: 1300 / 4000 [ 32%]  (Warmup) 
Chain 3 Iteration: 1400 / 4000 [ 35%]  (Warmup) 
Chain 3 Iteration: 1500 / 4000 [ 37%]  (Warmup) 
Chain 3 Iteration: 1600 / 4000 [ 40%]  (Warmup) 
Chain 3 Iteration: 1700 / 4000 [ 42%]  (Warmup) 
Chain 3 Iteration: 1800 / 4000 [ 45%]  (Warmup) 
Chain 3 Iteration: 1900 / 4000 [ 47%]  (Warmup) 
Chain 3 Iteration: 2000 / 4000 [ 50%]  (Warmup) 
Chain 3 Iteration: 2001 / 4000 [ 50%]  (Sampling) 
Chain 3 Iteration: 2100 / 4000 [ 52%]  (Sampling) 
Chain 3 Iteration: 2200 / 4000 [ 55%]  (Sampling) 
Chain 3 Iteration: 2300 / 4000 [ 57%]  (Sampling) 
Chain 3 Iteration: 2400 / 4000 [ 60%]  (Sampling) 
Chain 3 Iteration: 2500 / 4000 [ 62%]  (Sampling) 
Chain 3 Iteration: 2600 / 4000 [ 65%]  (Sampling) 
Chain 3 Iteration: 2700 / 4000 [ 67%]  (Sampling) 
Chain 3 Iteration: 2800 / 4000 [ 70%]  (Sampling) 
Chain 3 Iteration: 2900 / 4000 [ 72%]  (Sampling) 
Chain 3 Iteration: 3000 / 4000 [ 75%]  (Sampling) 
Chain 3 Iteration: 3100 / 4000 [ 77%]  (Sampling) 
Chain 3 Iteration: 3200 / 4000 [ 80%]  (Sampling) 
Chain 3 Iteration: 3300 / 4000 [ 82%]  (Sampling) 
Chain 4 Iteration:    1 / 4000 [  0%]  (Warmup) 
Chain 4 Iteration:  100 / 4000 [  2%]  (Warmup) 
Chain 4 Iteration:  200 / 4000 [  5%]  (Warmup) 
Chain 4 Iteration:  300 / 4000 [  7%]  (Warmup) 
Chain 4 Iteration:  400 / 4000 [ 10%]  (Warmup) 
Chain 4 Iteration:  500 / 4000 [ 12%]  (Warmup) 
Chain 4 Iteration:  600 / 4000 [ 15%]  (Warmup) 
Chain 4 Iteration:  700 / 4000 [ 17%]  (Warmup) 
Chain 4 Iteration:  800 / 4000 [ 20%]  (Warmup) 
Chain 4 Iteration:  900 / 4000 [ 22%]  (Warmup) 
Chain 4 Iteration: 1000 / 4000 [ 25%]  (Warmup) 
Chain 4 Iteration: 1100 / 4000 [ 27%]  (Warmup) 
Chain 4 Iteration: 1200 / 4000 [ 30%]  (Warmup) 
Chain 4 Iteration: 1300 / 4000 [ 32%]  (Warmup) 
Chain 4 Iteration: 1400 / 4000 [ 35%]  (Warmup) 
Chain 4 Iteration: 1500 / 4000 [ 37%]  (Warmup) 
Chain 4 Iteration: 1600 / 4000 [ 40%]  (Warmup) 
Chain 4 Iteration: 1700 / 4000 [ 42%]  (Warmup) 
Chain 4 Iteration: 1800 / 4000 [ 45%]  (Warmup) 
Chain 4 Iteration: 1900 / 4000 [ 47%]  (Warmup) 
Chain 4 Iteration: 2000 / 4000 [ 50%]  (Warmup) 
Chain 4 Iteration: 2001 / 4000 [ 50%]  (Sampling) 
Chain 4 Iteration: 2100 / 4000 [ 52%]  (Sampling) 
Chain 4 Iteration: 2200 / 4000 [ 55%]  (Sampling) 
Chain 4 Iteration: 2300 / 4000 [ 57%]  (Sampling) 
Chain 4 Iteration: 2400 / 4000 [ 60%]  (Sampling) 
Chain 4 Iteration: 2500 / 4000 [ 62%]  (Sampling) 
Chain 4 Iteration: 2600 / 4000 [ 65%]  (Sampling) 
Chain 4 Iteration: 2700 / 4000 [ 67%]  (Sampling) 
Chain 4 Iteration: 2800 / 4000 [ 70%]  (Sampling) 
Chain 4 Iteration: 2900 / 4000 [ 72%]  (Sampling) 
Chain 4 Iteration: 3000 / 4000 [ 75%]  (Sampling) 
Chain 4 Iteration: 3100 / 4000 [ 77%]  (Sampling) 
Chain 4 Iteration: 3200 / 4000 [ 80%]  (Sampling) 
Chain 4 Iteration: 3300 / 4000 [ 82%]  (Sampling) 
Chain 4 Iteration: 3400 / 4000 [ 85%]  (Sampling) 
Chain 1 Iteration: 2700 / 4000 [ 67%]  (Sampling) 
Chain 1 Iteration: 2800 / 4000 [ 70%]  (Sampling) 
Chain 1 Iteration: 2900 / 4000 [ 72%]  (Sampling) 
Chain 1 Iteration: 3000 / 4000 [ 75%]  (Sampling) 
Chain 1 Iteration: 3100 / 4000 [ 77%]  (Sampling) 
Chain 1 Iteration: 3200 / 4000 [ 80%]  (Sampling) 
Chain 1 Iteration: 3300 / 4000 [ 82%]  (Sampling) 
Chain 1 Iteration: 3400 / 4000 [ 85%]  (Sampling) 
Chain 1 Iteration: 3500 / 4000 [ 87%]  (Sampling) 
Chain 1 Iteration: 3600 / 4000 [ 90%]  (Sampling) 
Chain 1 Iteration: 3700 / 4000 [ 92%]  (Sampling) 
Chain 1 Iteration: 3800 / 4000 [ 95%]  (Sampling) 
Chain 1 Iteration: 3900 / 4000 [ 97%]  (Sampling) 
Chain 1 Iteration: 4000 / 4000 [100%]  (Sampling) 
Chain 2 Iteration: 3200 / 4000 [ 80%]  (Sampling) 
Chain 2 Iteration: 3300 / 4000 [ 82%]  (Sampling) 
Chain 2 Iteration: 3400 / 4000 [ 85%]  (Sampling) 
Chain 2 Iteration: 3500 / 4000 [ 87%]  (Sampling) 
Chain 2 Iteration: 3600 / 4000 [ 90%]  (Sampling) 
Chain 2 Iteration: 3700 / 4000 [ 92%]  (Sampling) 
Chain 2 Iteration: 3800 / 4000 [ 95%]  (Sampling) 
Chain 2 Iteration: 3900 / 4000 [ 97%]  (Sampling) 
Chain 2 Iteration: 4000 / 4000 [100%]  (Sampling) 
Chain 3 Iteration: 3400 / 4000 [ 85%]  (Sampling) 
Chain 3 Iteration: 3500 / 4000 [ 87%]  (Sampling) 
Chain 3 Iteration: 3600 / 4000 [ 90%]  (Sampling) 
Chain 3 Iteration: 3700 / 4000 [ 92%]  (Sampling) 
Chain 3 Iteration: 3800 / 4000 [ 95%]  (Sampling) 
Chain 3 Iteration: 3900 / 4000 [ 97%]  (Sampling) 
Chain 3 Iteration: 4000 / 4000 [100%]  (Sampling) 
Chain 4 Iteration: 3500 / 4000 [ 87%]  (Sampling) 
Chain 4 Iteration: 3600 / 4000 [ 90%]  (Sampling) 
Chain 4 Iteration: 3700 / 4000 [ 92%]  (Sampling) 
Chain 4 Iteration: 3800 / 4000 [ 95%]  (Sampling) 
Chain 4 Iteration: 3900 / 4000 [ 97%]  (Sampling) 
Chain 4 Iteration: 4000 / 4000 [100%]  (Sampling) 
Chain 1 finished in 0.3 seconds.
Chain 2 finished in 0.3 seconds.
Chain 3 finished in 0.3 seconds.
Chain 4 finished in 0.3 seconds.

All 4 chains finished successfully.
Mean chain execution time: 0.3 seconds.
Total execution time: 0.5 seconds.
Code
survival_fit$summary()
# A tibble: 123 × 10
   variable     mean median    sd   mad     q5    q95  rhat ess_bulk ess_tail
   <chr>       <dbl>  <dbl> <dbl> <dbl>  <dbl>  <dbl> <dbl>    <dbl>    <dbl>
 1 lp__       -33.0  -32.7  1.04  0.747 -35.0  -32.0   1.00    3431.    4734.
 2 scale        9.48   9.45 0.693 0.678   8.38  10.6   1.00    5912.    4789.
 3 shape        5.62   5.57 1.09  1.10    3.93   7.50  1.00    5308.    4987.
 4 log_lik[1]  -1.59  -1.56 0.378 0.369  -2.27  -1.03  1.00    5668.    4719.
 5 log_lik[2]  -1.59  -1.56 0.378 0.369  -2.27  -1.03  1.00    5668.    4719.
 6 log_lik[3]  -1.59  -1.56 0.378 0.369  -2.27  -1.03  1.00    5668.    4719.
 7 log_lik[4]  -1.59  -1.56 0.378 0.369  -2.27  -1.03  1.00    5668.    4719.
 8 log_lik[5]  -1.59  -1.56 0.378 0.369  -2.27  -1.03  1.00    5668.    4719.
 9 log_lik[6]  -1.59  -1.56 0.378 0.369  -2.27  -1.03  1.00    5668.    4719.
10 log_lik[7]  -1.59  -1.56 0.378 0.369  -2.27  -1.03  1.00    5668.    4719.
# ℹ 113 more rows
Code
import polars as pl, numpy as np
import multiprocessing

failure_data = pl.read_csv("../../data/failures.csv").with_columns([
    pl.col("fail_ub").cast(pl.Float64),
    pl.col("fail_lb").cast(pl.Float64)
])

# Define a large finite number to substitute for infinity.
large_num = 1e10

# Convert the fail_ub array from the failure_data, and replace inf values.
fail_ub = failure_data["fail_ub"].to_numpy().copy()
fail_ub[~np.isfinite(fail_ub)] = large_num

# Prepare your model_data dictionary.
model_data = {
    "n_meas": failure_data.shape[0],
    "obs_time": [12] * failure_data.shape[0],
    "fail_lb": failure_data["fail_lb"].to_numpy(),
    "fail_ub": fail_ub,
    "fail_status": (failure_data["fail_ub"].is_finite().cast(pl.Int64)).to_numpy(),
    "n_pred": 100,
    "pred_time": np.linspace(start=0, stop=20, num=100)
}

survival_fit = survival_model.sample(
  data = model_data,
  chains = 4,
  parallel_chains = 1,
  seed = 231123,
  iter_warmup = 2000,
  iter_sampling = 2000
)
                                                                                                                                                                                                                                                                                                                                

INFO:cmdstanpy:CmdStan start processing

chain 1 |          | 00:00 Status

chain 2 |          | 00:00 Status


chain 3 |          | 00:00 Status



chain 4 |          | 00:00 Status
chain 1 |#####4    | 00:00 Iteration: 2001 / 4000 [ 50%]  (Sampling)
chain 1 |#########7| 00:00 Iteration: 3800 / 4000 [ 95%]  (Sampling)

chain 2 |2         | 00:00 Status

chain 2 |######1   | 00:00 Iteration: 2300 / 4000 [ 57%]  (Sampling)

chain 2 |#########5| 00:00 Iteration: 3700 / 4000 [ 92%]  (Sampling)


chain 3 |2         | 00:00 Status


chain 3 |######1   | 00:00 Iteration: 2300 / 4000 [ 57%]  (Sampling)


chain 3 |#########5| 00:00 Iteration: 3700 / 4000 [ 92%]  (Sampling)



chain 4 |2         | 00:01 Status



chain 4 |######1   | 00:01 Iteration: 2300 / 4000 [ 57%]  (Sampling)



chain 4 |#########5| 00:01 Iteration: 3700 / 4000 [ 92%]  (Sampling)
chain 1 |##########| 00:01 Sampling completed                       

chain 2 |##########| 00:01 Sampling completed                       

chain 3 |##########| 00:01 Sampling completed                       

chain 4 |##########| 00:01 Sampling completed                       
INFO:cmdstanpy:CmdStan done processing.
Code
survival_fit.summary()
                   Mean     MCSE  StdDev     5%  ...   95%   N_Eff  N_Eff/s  R_hat
name                                             ...                              
lp__             -33.00  0.01800   1.000 -35.00  ... -32.0  3500.0   3900.0    1.0
scale              9.50  0.00910   0.690   8.40  ...  11.0  5800.0   6500.0    1.0
shape              5.60  0.01500   1.100   3.90  ...   7.5  5400.0   6000.0    1.0
log_lik[1]        -1.60  0.00500   0.380  -2.30  ...  -1.0  5750.0   6425.0    1.0
log_lik[2]        -1.60  0.00500   0.380  -2.30  ...  -1.0  5750.0   6425.0    1.0
...                 ...      ...     ...    ...  ...   ...     ...      ...    ...
p_fail_pred[96]    0.97  0.00033   0.022   0.93  ...   1.0  4457.0   4980.0    1.0
p_fail_pred[97]    0.98  0.00032   0.021   0.93  ...   1.0  4452.0   4975.0    1.0
p_fail_pred[98]    0.98  0.00031   0.021   0.94  ...   1.0  4448.0   4970.0    1.0
p_fail_pred[99]    0.98  0.00030   0.020   0.94  ...   1.0  4444.0   4965.0    1.0
p_fail_pred[100]   0.98  0.00029   0.019   0.94  ...   1.0  4441.0   4962.0    1.0

[123 rows x 9 columns]
Code
using CSV, DataFrames

failure_data = CSV.read("../../data/failures.csv", DataFrame)

survival_fit = loglogistic_survival(
    repeat([12.0], nrow(failure_data)),
    failure_data.fail_lb,
    failure_data.fail_ub,
    isfinite.(failure_data.fail_ub) |> x -> Int.(x)
) |> model -> sample(MersenneTwister(231123), model, NUTS(), MCMCThreads(), 2000, 4)
Chains MCMC chain (2000×14×4 Array{Float64, 3}):

Iterations        = 1001:1:3000
Number of chains  = 4
Samples per chain = 2000
Wall duration     = 15.31 seconds
Compute duration  = 11.92 seconds
parameters        = scale, shape
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64   ⋯

       scale    9.4688    0.6886    0.0090   5854.0791   4655.8635    1.0006   ⋯
       shape    5.6014    1.0963    0.0143   5854.6561   5483.7428    1.0012   ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

       scale    8.1597    9.0166    9.4565    9.9064   10.9001
       shape    3.6211    4.8263    5.5506    6.3210    7.8421

expected information gain

expected information gain

  • quantify uncertainty in posterior predictions
  • identify prospective data collection options
  • generate all possible outcome scenarios
    • here (helpfully): failure or no failure
  • for each outcome:
    • simulate the data collection and re-fit the model
    • quantify uncertainty in the new posterior predictions
    • find the difference (reduction in uncertainty with the new data)
    • weight the reduction by the probability of the outcome
  • compare the expected “information gain” to rank order data collection options

measures of uncertainty

  • entropy?
  • log-likelihood?
  • kernel density estimation?
  • variance?
Code
post_pred |> head()
# A tibble: 6 × 5
  Parameter      Chain Iteration value  time
  <chr>          <int>     <int> <dbl> <dbl>
1 p_fail_pred[1]     1         1     0     0
2 p_fail_pred[1]     1         2     0     0
3 p_fail_pred[1]     1         3     0     0
4 p_fail_pred[1]     1         4     0     0
5 p_fail_pred[1]     1         5     0     0
6 p_fail_pred[1]     1         6     0     0
Code
estimate_uncertainty <- function(posterior = post_pred) {
  # Baseline uncertainty: variance of the posterior predictive draws
  # at each prediction time point.
  grouped_draws <- group_by(posterior, time)
  summarise(grouped_draws, uncertainty_base = var(value))
}

estimate_uncertainty() |> head()
# A tibble: 6 × 2
   time uncertainty_base
  <dbl>            <dbl>
1 0             0       
2 0.202         4.57e-12
3 0.404         1.39e-10
4 0.606         1.16e- 9
5 0.808         5.65e- 9
6 1.01          2.05e- 8
Code
post_pred.head()
shape: (5, 5)
Chain Iteration Parameter value time
i64 i64 str f64 f64
1 1 "p_fail_pred[1]" 0.0 0.0
1 2 "p_fail_pred[1]" 0.0 0.0
1 3 "p_fail_pred[1]" 0.0 0.0
1 4 "p_fail_pred[1]" 0.0 0.0
1 5 "p_fail_pred[1]" 0.0 0.0
Code
def estimate_uncertainty(posterior=post_pred):
    """Variance of the posterior predictive draws at each prediction time.

    Mirrors the R version: group the draws by ``time`` and compute the
    per-group variance of ``value``, sorted by time.
    """
    grouped = posterior.group_by("time")
    per_time_variance = grouped.agg(uncertainty=pl.col("value").var())
    return per_time_variance.sort("time")

estimate_uncertainty().head()
shape: (5, 2)
time uncertainty
f64 f64
0.0 0.0
0.20202 4.5745e-12
0.40404 1.3886e-10
0.606061 1.1557e-9
0.808081 5.6478e-9

expected information gain

Code
estimate_information_gain <- function(proposed_time) {
  # Expected information gain from one additional inspection at proposed_time:
  # re-fit the model under each possible outcome (failure / no failure),
  # measure the reduction in posterior predictive variance, and weight each
  # reduction by the predicted probability of that outcome.

  # We need two hypothetical datasets, one per outcome. (The previous
  # `fail_data <- model_data -> no_fail_data` chained right- and
  # left-assignment; written out explicitly here for readability.)
  fail_data <- model_data
  no_fail_data <- model_data

  # case A: we observe a failure (interval-censored in (t - 1.5, t])
  fail_data$n_meas <- fail_data$n_meas + 1
  fail_data$obs_time <- c(fail_data$obs_time, proposed_time)
  fail_data$fail_lb <- c(fail_data$fail_lb, proposed_time - 1.5)
  fail_data$fail_ub <- c(fail_data$fail_ub, proposed_time)
  fail_data$fail_status <- c(fail_data$fail_status, 1)

  # case B: we do not observe a failure (right-censored at t)
  no_fail_data$n_meas <- no_fail_data$n_meas + 1
  no_fail_data$obs_time <- c(no_fail_data$obs_time, proposed_time)
  no_fail_data$fail_lb <- c(no_fail_data$fail_lb, proposed_time)
  no_fail_data$fail_ub <- c(no_fail_data$fail_ub, Inf)
  no_fail_data$fail_status <- c(no_fail_data$fail_status, 0)

  # re-fitting our models for each possible outcome
  fail_fit <- survival_model$sample(
    data = fail_data,
    chains = 4,
    parallel_chains = parallel::detectCores(),
    seed = 231123,
    iter_warmup = 2000,
    iter_sampling = 2000
  )

  no_fail_fit <- survival_model$sample(
    data = no_fail_data,
    chains = 4,
    parallel_chains = parallel::detectCores(),
    seed = 231123,
    iter_warmup = 2000,
    iter_sampling = 2000
  )

  # quantify uncertainty in the new predictions
  base_uncertainties <- estimate_uncertainty()

  # NOTE(review): the rep(..., each = iter_sampling * n_chains) time labelling
  # assumes tidy_mcmc_draws orders draws parameter-by-parameter — confirm.
  fail_uncertainties <- fail_fit |>
    DomDF::tidy_mcmc_draws(params = pred_params) |>
    mutate(time = rep(x = model_data$pred_time, 
           each = fail_fit$metadata()$iter_sampling * length(fail_fit$metadata()$id))) |>
    estimate_uncertainty() |> rename(uncertainty_fail = uncertainty_base)

  no_fail_uncertainties <- no_fail_fit |>
    DomDF::tidy_mcmc_draws(params = pred_params) |>
    mutate(time = rep(x = model_data$pred_time, 
           each = no_fail_fit$metadata()$iter_sampling * length(no_fail_fit$metadata()$id))) |>
    estimate_uncertainty() |> rename(uncertainty_no_fail = uncertainty_base)

  # prior probability of observing a failure: posterior predictive failure
  # probability at the prediction time closest to proposed_time
  p_fail <- post_pred |>
    filter(abs(time - proposed_time) == min(abs(time - proposed_time))) |>
    summarise(p = mean(value)) |>
    pull(p)

  information_gains <- base_uncertainties |>
    left_join(fail_uncertainties, by = "time") |>
    left_join(no_fail_uncertainties, by = "time") |>
    mutate(
      # weighted uncertainty reduction; negative "reductions" (increased
      # variance) are clipped to zero so they contribute nothing
      weighted_reduction = pmax(0, (uncertainty_base - uncertainty_fail)) * p_fail +
                           pmax(0, (uncertainty_base - uncertainty_no_fail)) * (1 - p_fail)

    )

  # return the expected information gain
  return(information_gains$weighted_reduction |> sum())
}
Code
import copy

def estimate_information_gain(proposed_time):
    """Expected reduction in posterior predictive variance from one new
    inspection at ``proposed_time``.

    For each possible outcome (failure observed / not observed) the model is
    re-fit with the hypothetical data point, the per-time variance of the
    predictions is compared against the baseline, and the (non-negative)
    reductions are weighted by the predicted probability of each outcome.
    """

    def _as_list(values):
        # numpy arrays expose .tolist(); plain sequences are copied via list().
        return values.tolist() if hasattr(values, "tolist") else list(values)

    def _augmented_data(obs_time, fail_lb, fail_ub, fail_status):
        # Copy the observed dataset and append one hypothetical observation.
        data = dict(model_data)
        data["n_meas"] = model_data["n_meas"] + 1
        data["obs_time"] = _as_list(model_data["obs_time"]) + [obs_time]
        data["fail_lb"] = _as_list(model_data["fail_lb"]) + [fail_lb]
        data["fail_ub"] = _as_list(model_data["fail_ub"]) + [fail_ub]
        data["fail_status"] = _as_list(model_data["fail_status"]) + [fail_status]
        return data

    # case A: we observe a failure (interval-censored in (t - 1.5, t])
    fail_data = _augmented_data(proposed_time, proposed_time - 1.5, proposed_time, 1)
    # case B: no failure by proposed_time (right-censored; large_num stands in for Inf)
    no_fail_data = _augmented_data(proposed_time, proposed_time, large_num, 0)

    def _fit(data):
        # Re-fit the survival model with the hypothetical dataset.
        return survival_model.sample(
            data=data,
            chains=4,
            parallel_chains=multiprocessing.cpu_count(),
            seed=231123,
            iter_warmup=2000,
            iter_sampling=2000,
        )

    fail_fit = _fit(fail_data)
    no_fail_fit = _fit(no_fail_data)

    # Only compare uncertainty at prediction times near the proposed inspection.
    window = 2.0

    def _uncertainty(draws, colname):
        # Per-time variance of the draws within +/- window of proposed_time.
        return (
            draws
            .filter((pl.col("time") - proposed_time).abs() <= window)
            .group_by("time")
            .agg(pl.col("value").var().alias(colname))
            .sort("time")
        )

    base_uncertainties = _uncertainty(post_pred, "uncertainty_base")
    fail_post = _uncertainty(
        process_mcmc_draws(fail_fit, pred_params), "uncertainty_fail"
    )
    no_fail_post = _uncertainty(
        process_mcmc_draws(no_fail_fit, pred_params), "uncertainty_no_fail"
    )

    # Probability of the failure outcome: posterior predictive failure
    # probability at the prediction time closest to proposed_time.
    min_diff = (
        post_pred
        .select((pl.col("time") - proposed_time).abs().alias("diff"))
        .select(pl.col("diff").min())
        .item()
    )
    p_fail = (
        post_pred
        .filter((pl.col("time") - proposed_time).abs() == min_diff)
        .select(pl.col("value").mean().alias("p"))
        .item()
    )

    def _reduction(colname):
        # Uncertainty reduction vs. baseline, clipped at zero so that an
        # increase in variance contributes nothing.
        diff = pl.col("uncertainty_base") - pl.col(colname)
        return pl.when(diff > 0).then(diff).otherwise(0)

    information_gains = (
        base_uncertainties
        .join(fail_post, on="time", how="left")
        .join(no_fail_post, on="time", how="left")
        .with_columns(
            weighted_reduction=(
                _reduction("uncertainty_fail") * p_fail
                + _reduction("uncertainty_no_fail") * (1 - p_fail)
            )
        )
    )

    # Total expected information gain (sum over weighted_reduction).
    return information_gains.select(pl.col("weighted_reduction")).sum().item()

expected information gain

experimental design

  • what do we want to achieve with data collection?
    • reduce uncertainty in predictions?
    • test a hypothesis?
    • support decision-making? (see “value of information analysis”)

break?